import dgl
import numpy as np
import torch
import argparse
import time
from load_dataset import load_ogb, load_reddit, load_citeseer, load_flickr
from dgl.data.utils import save_graphs
import os

def metis_partition_dataset(graph, num_parts, balance_ntypes):
    """
    Partition the graph dataset into num_parts parts based on the METIS partitioning algorithm.

    The node assignment is computed with ``dgl.metis_partition``. Each returned
    subgraph is then induced from the *original* graph on the nodes assigned to
    that partition, via ``dgl.node_subgraph``.

    Args:
        graph (dgl.DGLGraph): The input graph.
        num_parts (int): The number of parts to partition the graph into.
        balance_ntypes: Forwarded unchanged to ``dgl.metis_partition`` as its
            ``balance_ntypes`` argument (e.g. ``graph.ndata['train_mask']`` to
            balance the training nodes across partitions). Pass ``None`` to
            partition without this balancing constraint.

    Returns:
        list: A list of subgraphs representing the partitioned dataset.

    Raises:
        RuntimeError: If ``dgl.metis_partition`` returns ``None`` instead of a
            partitioning.
    """
    # Honour the caller's choice: previously the train mask was always used
    # here, which silently ignored the ``balance_ntypes`` argument.
    partitions = dgl.metis_partition(graph, num_parts, balance_ntypes=balance_ntypes)
    if partitions is None:
        # NOTE(review): DGL appears to return None when METIS fails -- confirm
        # against the installed DGL version. Fail with a clear message rather
        # than an opaque "NoneType is not iterable".
        raise RuntimeError(
            'METIS failed to partition the graph into {} parts'.format(num_parts))

    subgraphs = []
    for part in partitions.values():
        # dgl.NID holds, for each node of the partition, its ID in `graph`.
        nodes = part.ndata[dgl.NID]

        # Induce the subgraph based on the nodes from graph
        subgraphs.append(dgl.node_subgraph(graph, nodes))

    return subgraphs

def random_partition(graph, num_parts):
    """
    Split the graph into num_parts node-induced subgraphs chosen at random.

    The node IDs are shuffled with NumPy's global random state and cut into
    consecutive slices of ``num_nodes // num_parts`` nodes each. The final
    slice runs to the end of the shuffled IDs, so it also holds any remainder.

    Args:
        graph (dgl.DGLGraph): The input graph.
        num_parts (int): The number of parts to partition the graph into.

    Returns:
        list: A list of subgraphs representing the partitioned dataset.
    """
    total = graph.number_of_nodes()
    shuffled_ids = np.arange(total)
    np.random.shuffle(shuffled_ids)

    chunk = total // num_parts
    # Slice boundaries: the start offset of every part, then the overall end
    # so that the last part absorbs the leftover nodes.
    bounds = [idx * chunk for idx in range(num_parts)] + [total]

    return [
        dgl.node_subgraph(graph, shuffled_ids[lo:hi])
        for lo, hi in zip(bounds, bounds[1:])
    ]

def main():
    """Command-line entry point: load a builtin dataset, partition it, save it.

    The partitioned subgraphs are written to
    ``partitioned_dataset/<dataset>_<partition_method>_<num_parts>.bin``.
    If that file already exists, nothing is recomputed.

    Raises:
        ValueError: If the partition method or dataset is unknown, or if
            ``--num_parts`` is not a positive integer.
    """
    argparser = argparse.ArgumentParser("Partition builtin graphs")
    argparser.add_argument('--dataset', type=str, default='ogbn-arxiv',
                           help='datasets: ogbn-arxiv, ogbn-products, reddit, citeseer, flickr')
    argparser.add_argument('--num_parts', type=int, default=5,
                           help='number of partitions')
    argparser.add_argument('--balance_train', action='store_true',
                           help='balance the training size in each partition.')
    # NOTE(review): --balance_edges is parsed but not forwarded to the
    # partitioner anywhere in this script.
    argparser.add_argument('--balance_edges', action='store_true',
                           help='balance the number of edges in each partition.')
    argparser.add_argument('--partition_method', type=str, default='metis',
                           help='partition method: metis')
    args = argparser.parse_args()

    # Reject unsupported options up front, before the (potentially slow)
    # dataset load and before touching the filesystem.
    if args.partition_method != 'metis':
        raise ValueError('Unknown partition method: {}'.format(args.partition_method))
    if args.num_parts < 1:
        raise ValueError('num_parts must be a positive integer, got {}'.format(args.num_parts))

    loaders = {
        'ogbn-products': lambda: load_ogb('ogbn-products'),
        'ogbn-arxiv': lambda: load_ogb('ogbn-arxiv'),
        'reddit': load_reddit,
        'citeseer': load_citeseer,
        'flickr': load_flickr,
    }
    if args.dataset not in loaders:
        raise ValueError('Unknown dataset: {}'.format(args.dataset))

    output_dir = "partitioned_dataset"
    saving_graph_name = args.dataset + "_" + args.partition_method + "_" + str(args.num_parts)
    output_path = os.path.join(output_dir, saving_graph_name + ".bin")

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_dir, exist_ok=True)

    # Skip the work entirely if this partitioning was already produced.
    if os.path.isfile(output_path):
        print("Partitioned dataset already exists. Exiting...")
        return

    load_start = time.time()
    g, _ = loaders[args.dataset]()
    print('load {} takes {:.3f} seconds'.format(args.dataset, time.time() - load_start))
    print('|V|={}, |E|={}'.format(g.number_of_nodes(), g.number_of_edges()))
    print('train: {}, valid: {}, test: {}'.format(torch.sum(g.ndata['train_mask']),
                                                  torch.sum(g.ndata['val_mask']),
                                                  torch.sum(g.ndata['test_mask'])))

    balance_ntypes = g.ndata['train_mask'] if args.balance_train else None

    # Time the partitioning on its own clock so the figure excludes loading.
    partition_start = time.time()
    subgraphs = metis_partition_dataset(g, args.num_parts, balance_ntypes)
    print('partitioning takes {:.3f} seconds'.format(time.time() - partition_start))

    # save the partitioned dataset in output_dir
    print('Saving partitioned dataset to {}'.format(output_path))
    save_graphs(output_path, subgraphs)


if __name__ == '__main__':
    main()

    